Track All Changes from upstream #45
base: origin
Conversation
Implements CheckpointTensor backend
* save * save * review * review * add back files
* save * save * review * remove outdated file * save * can compile again * save * save * up up array * finish the overloads * replace vec[n] with vec.at(n)
* commit * add more overloads * fix log * save
* save * save * save * save * use release instead of free
* save * fix comment
Overloads for mean and mean.dim
Add operator overloads for U-Net
Restore overloads needed for LSTM and GRU
* save * save * save
Various overloads for the ACT model
Overloads and changes needed for transformer
Overloads for topk functions
Add overloads for Deepspeech and more for ACT
save before change save before change add aliaspool design, and add weak/strong pointer discussion add more code rebase add allocator hook save metadata to prepare for eviction save refactor: move log to a separate file add file raii save save comment save save save save find error, bisecting code save save address review comment address comment address comment fix segfault save save pin save save save save save
refactor - remove stitch
[ impl ] overload diag & mv
save save restore equivalentclassnode save save save 50% resnet here we go 50% resnet here we go save save save
Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  // Allocate a zero gradient with the original input shape, then copy the
  // incoming gradient into the region the forward slice selected.
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.slice(dim, start, end, step).copy_(grad);
  return grad_input;
}
Do you know what's going on here? Did you just move these here from the manual variable type definitions? If so, we might be able to just directly upstream these hunks.
That is indeed true. They were manual, so I couldn't overload them. I exposed them in the functions.yaml file for overloading purposes.
namespace at { namespace native {

Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) {
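For context on the pattern being discussed, a complete overload in this style might look roughly like the sketch below. The Tensors, rematerialize_function_t, and CheckpointTensorImpl::make names are assumptions drawn from this thread's description of the backend, not verified names from this branch.

// Sketch only: a plausible body for one hand-written checkpoint overload.
// Helper names (Tensors, rematerialize_function_t, CheckpointTensorImpl::make)
// are assumed here for illustration.
Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) {
  // Capture everything needed to recompute the output if it gets evicted.
  rematerialize_function_t remat =
    [=](const Tensors& inputs) -> Tensors {
      return {at::add(inputs.at(0), inputs.at(1), c)};
    };
  return CheckpointTensorImpl::make("add", remat, {a, b})[0];
}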
oh boy, was this all written by hand? Hopefully backend fallback can turn this into a 20 line file.
great news! every time we need a new operator, we have to edit the yaml files, then wait one hour for recompilation. writing them by hand isn't the bottleneck in comparison...
how does backend fallback deal with mutation, though? we have special treatment for mutation.
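For illustration (not code from this PR): one way such special treatment can look is to funnel in-place ops through a dedicated mutation hook rather than the functional wrapper shown above, so the backend's bookkeeping records the write. The mutate_function_t and CheckpointTensorImpl::mutate names below are assumptions.

// Illustration only: an in-place variant routed through an assumed mutation hook.
Tensor& checkpoint_add_(Tensor& self, const Tensor& other, c10::Scalar alpha) {
  mutate_function_t mut =
    [=](const Tensors& inputs) {
      // The real write happens on the (possibly rematerialized) value.
      inputs.at(0).add_(inputs.at(1), alpha);
    };
  // The trailing {0} marks which argument is mutated so it can be re-recorded.
  CheckpointTensorImpl::mutate("add_", mut, {self, other}, {0});
  return self;
}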
Backend fallback would get the name of the operator in question, so I guess you could check whether it is an in-place op or not and do something different in that case.
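As a rough sketch of that suggestion, using the boxed fallback API (PrivateUse1 stands in for whatever dispatch key the checkpoint backend would actually register under):

#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

// Sketch only: a single boxed fallback that branches on in-place ops.
// ATen in-place variants end in "_" (e.g. "aten::add_").
void checkpoint_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const std::string& name = op.schema().name();
  const bool is_inplace = !name.empty() && name.back() == '_';
  if (is_inplace) {
    // Special treatment for mutation would go here (e.g. snapshot the inputs
    // before letting the underlying kernel write into them).
  }
  // Re-dispatch to the underlying kernel with our key excluded to avoid recursion.
  c10::impl::ExcludeDispatchKeyGuard no_recurse(c10::DispatchKey::PrivateUse1);
  op.callBoxed(stack);
}

TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&checkpoint_fallback>());
}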
This is from PyTorch commit 1546d2a.